"""
Phase 3: Mechanisms & Interactions — Sovereign Spreads
=======================================================
Tests:
  (a) Z × debt/GDP: does aging amplify debt sensitivity?
  (b) Z × KAOPEN: open capital accounts and demographic risk
  (c) Z × fiscal_bal: demographic erosion of fiscal space
  (d) Mediation: does Z affect spreads through fiscal or directly?
  (e) Dynamic models with lagged spread
  (f) Time subsamples (pre/post GFC)
  (g) Chow tests for structural breaks

Output (under output/tables/): interactions_ratings.md, interactions_spreads.md,
  mediation.md, dynamic_models.md, structural_break.md
"""

import sys
from pathlib import Path

import numpy as np
import pandas as pd
from scipy import stats

# Root of the sovereign-spreads sub-project.
# NOTE(review): this is a hard-coded absolute path, so the script only runs
# on a machine where this location exists.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/sovereign_spreads")
ROOT_DIR = PROJECT_DIR.parent
# Put the sibling "multilateral/src" directory first on the import path so the
# `model` module imported on the next line is resolved from there.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input panel location (main() reads spread_panel.csv from here) and the
# directory the markdown tables are written to.
PROCESSED_DIR = PROJECT_DIR / "data" / "processed"
TABLES_DIR = PROJECT_DIR / "output" / "tables"

# 38 ISO3 country codes used by main() to select the OECD subsample (model A4).
OECD_38 = [
    "AUS", "AUT", "BEL", "CAN", "CHL", "COL", "CRI", "CZE", "DNK", "EST",
    "FIN", "FRA", "DEU", "GRC", "HUN", "ISL", "IRL", "ISR", "ITA", "JPN",
    "KOR", "LVA", "LTU", "LUX", "MEX", "NLD", "NZL", "NOR", "POL", "PRT",
    "SVK", "SVN", "ESP", "SWE", "CHE", "TUR", "GBR", "USA",
]


def stars(p):
    """Return the significance marker for p-value *p*.

    ``'***'`` for p < 0.01, ``'**'`` for p < 0.05, ``'*'`` for p < 0.10 and
    an empty string otherwise. All comparisons are strict, so a value that
    sits exactly on a cutoff falls into the weaker category.
    """
    # Cutoffs are ordered from strictest to loosest; the first match wins.
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.10, '*')):
        if p < cutoff:
            return marker
    return ''


def run_model(df, dep_var, regressors, label, min_obs=50):
    """Fit one PanelGLS model on the complete-case subsample and print it.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel data. Besides the model variables it must contain ``iso3`` and
        ``year`` columns, which are handed to ``PanelGLS.fit``.
    dep_var : str
        Name of the dependent-variable column.
    regressors : list of str
        Requested right-hand-side columns. Names that are not columns of
        *df* are dropped from the model; the dropped names are printed so a
        model never silently loses the terms its label advertises.
    label : str
        Human-readable model name used in console output and stored in the
        result.
    min_obs : int, optional
        Minimum number of complete-case rows required to fit (default 50).

    Returns
    -------
    dict or None
        ``None`` when the model is skipped (dependent variable missing, no
        usable regressor, or fewer than *min_obs* complete rows). Otherwise
        a dict with ``label``, ``dep_var``, ``n_obs``, ``n_countries``,
        ``r_squared``, ``rho`` and, for each fitted regressor ``x``,
        ``coef_x``, ``se_x`` and ``p_x``.
    """
    if dep_var not in df.columns:
        print(f"  [{label}] {dep_var} missing — skipping")
        return None

    available = [r for r in regressors if r in df.columns]
    dropped = [r for r in regressors if r not in df.columns]
    if dropped:
        print(f"  [{label}] regressors not in data, dropped: "
              f"{', '.join(map(str, dropped))}")
    if not available:
        # An empty design matrix cannot be estimated; skip like the other
        # unusable cases instead of handing it to the estimator.
        print(f"  [{label}] no usable regressors — skipping")
        return None

    # Complete-case sample for exactly the variables in this model.
    sub = df.dropna(subset=[dep_var] + available).copy()
    if len(sub) < min_obs:
        print(f"  [{label}] Insufficient obs ({len(sub)}) — skipping")
        return None

    gls = PanelGLS()
    gls.fit(sub[dep_var].values, sub[available].values,
            sub['iso3'].values, sub['year'].values)
    print(f"\n  [{label}]  N={gls.n_obs}, countries={gls.n_countries}, "
          f"R²={gls.r_squared:.4f}")

    results = {
        'label': label, 'dep_var': dep_var,
        'n_obs': gls.n_obs, 'n_countries': gls.n_countries,
        'r_squared': gls.r_squared, 'rho': gls.rho,
    }
    # The i-th entry of beta / se / pvalues is read as belonging to the i-th
    # regressor column passed to fit().
    for i, name in enumerate(available):
        results[f'coef_{name}'] = gls.beta[i]
        results[f'se_{name}'] = gls.se[i]
        results[f'p_{name}'] = gls.pvalues[i]
        sig = stars(gls.pvalues[i])
        print(f"    {name:<30} {gls.beta[i]:>10.4f} ({gls.se[i]:.4f}) {sig}")
    return results


def build_table(results, key_vars, notes, filename, title):
    """Write a markdown summary of fitted models to ``TABLES_DIR / filename``.

    The file has two tables: one row per model with sample size and R², then
    one row per (model, variable) pair for every name in *key_vars* that the
    model actually has a ``coef_<name>`` entry for.

    Parameters
    ----------
    results : list of dict
        Result dicts with the keys ``label``, ``dep_var``, ``n_obs``,
        ``n_countries``, ``r_squared`` and ``coef_/se_/p_<var>`` entries.
        If empty (or otherwise falsy) nothing is written.
    key_vars : list of str
        Variable names to report in the coefficient table, in display order.
    notes : str
        Footnote text, emitted in italics at the end of the file.
    filename : str
        File name (relative to ``TABLES_DIR``) of the markdown file.
    title : str
        Top-level heading of the document.

    Returns
    -------
    None
    """
    if not results:
        return

    md = [
        f"# {title}\n",
        "| Model | Dep Var | N | Countries | R² |",
        "|---|---|---|---|---|",
    ]
    md.extend(
        f"| {r['label']} | {r['dep_var']} | {r['n_obs']:,} "
        f"| {r['n_countries']} | {r['r_squared']:.3f} |"
        for r in results
    )

    md.append("\n## Key Coefficients\n")
    md.append("| Model | Variable | Coef | SE | p-value | Sig |")
    md.append("|---|---|---|---|---|---|")
    for r in results:
        for var in key_vars:
            ckey = f'coef_{var}'
            if ckey not in r:
                # This model did not include the variable; no row for it.
                continue
            p = r[f'p_{var}']
            md.append(f"| {r['label']} | {var} | {r[ckey]:.4f} "
                      f"| {r[f'se_{var}']:.4f} | {p:.4f} | {stars(p)} |")
    md.append(f"\n*{notes}*")

    out = TABLES_DIR / filename
    # A fresh checkout may not have the output directory yet.
    out.parent.mkdir(parents=True, exist_ok=True)
    # The table contains non-ASCII text ("R²"), so pin the encoding instead
    # of depending on the platform's locale default.
    out.write_text('\n'.join(md), encoding='utf-8')
    print(f"\n  Saved: {out}")


def _fit_into(bucket, data, dep_var, regressors, label):
    """Fit one model with ``run_model`` and append a non-empty result to *bucket*.

    Returns the result dict, or ``None`` when ``run_model`` skipped the
    model, so the caller can keep a direct handle on one specific fit.
    """
    result = run_model(data, dep_var, regressors, label)
    if result:
        bucket.append(result)
    return result


def _gls_rss(data, dep_var, regressors):
    """Fit PanelGLS on *data* and return the sum of squared ``y - X @ beta``.

    NOTE(review): the residual is formed from the untransformed columns and
    the slope vector only. Whether that equals the estimator's own residual
    (intercept, fixed effects, serial-correlation transform) depends on
    PanelGLS internals that are not visible in this file — confirm.
    """
    y = data[dep_var].values
    X = data[regressors].values
    gls = PanelGLS()
    gls.fit(y, X, data['iso3'].values, data['year'].values)
    return np.sum((y - X @ gls.beta) ** 2)


def main():
    """Run all Phase 3 model families and write one markdown table per part.

    Reads ``spread_panel.csv`` from ``PROCESSED_DIR`` and then, in order:

    * Part A - interaction models on ratings  -> interactions_ratings.md
    * Part B - interaction models on spreads  -> interactions_spreads.md
    * Part C - sequential fiscal controls     -> mediation.md
    * Part D - dynamic / change / downgrade   -> dynamic_models.md
    * Part E - pre/post-GFC fits + Chow test  -> structural_break.md

    Every model goes through ``run_model``, which skips a model whose
    dependent variable is missing or whose complete-case sample is too small;
    skipped models are simply absent from the tables.
    """
    print("=" * 70)
    print("PHASE 3: Mechanisms & Interactions — Sovereign Spreads")
    print("=" * 70)

    df = pd.read_csv(PROCESSED_DIR / "spread_panel.csv")
    print(f"Panel: {df['iso3'].nunique()} countries, {len(df):,} obs")

    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    controls_rating = ['rgdp_growth', 'fiscal_bal_gdp', 'kaopen']
    controls_spread = ['rgdp_growth', 'fiscal_bal_gdp', 'kaopen', 'nfa_gdp_lag']
    # Rating specification with the debt ratio added; reused by several parts.
    rating_with_debt = demo_vars + controls_rating + ['govt_debt_gdp']

    # ═══════════════════════════════════════════════════════════════════
    # PART A: INTERACTION MODELS ON RATINGS
    # ═══════════════════════════════════════════════════════════════════
    print("\n" + "=" * 70)
    print("PART A: INTERACTION MODELS ON RATINGS")
    print("=" * 70)

    int_results = []

    # Interaction columns are expected to be pre-built in the panel; only the
    # ones actually present are used.
    int_vars_debt = [v for v in ['Z_1_x_debt', 'Z_2_x_debt', 'Z_3_x_debt']
                     if v in df.columns]
    int_vars_fiscal = [v for v in ['Z_1_x_fiscal', 'Z_2_x_fiscal', 'Z_3_x_fiscal']
                       if v in df.columns]
    int_vars_ka = [v for v in ['Z_1_x_kaopen', 'Z_2_x_kaopen', 'Z_3_x_kaopen']
                   if v in df.columns]

    # A1: Z × debt/GDP → rating
    _fit_into(int_results, df, 'rating_numeric',
              rating_with_debt + int_vars_debt,
              "A1: Z×debt → rating")

    # A2: Z × fiscal → rating
    _fit_into(int_results, df, 'rating_numeric',
              demo_vars + controls_rating + int_vars_fiscal,
              "A2: Z×fiscal → rating")

    # A3: Z × KAOPEN → rating
    _fit_into(int_results, df, 'rating_numeric',
              demo_vars + controls_rating + int_vars_ka,
              "A3: Z×KAOPEN → rating")

    # A4: OECD Z×debt → rating
    oecd = df[df['iso3'].isin(OECD_38)].copy()
    _fit_into(int_results, oecd, 'rating_numeric',
              rating_with_debt + int_vars_debt,
              "A4: OECD Z×debt → rating")

    key_int_vars = (demo_vars + ['govt_debt_gdp'] + int_vars_debt
                    + int_vars_fiscal + int_vars_ka)
    build_table(int_results, key_int_vars,
                "Interaction models test whether demographic structure moderates "
                "the impact of fiscal/financial variables on ratings",
                "interactions_ratings.md",
                "Interaction Models: Demographics × Fiscal/Financial → Ratings")

    # ═══════════════════════════════════════════════════════════════════
    # PART B: INTERACTION MODELS ON SPREADS
    # ═══════════════════════════════════════════════════════════════════
    print("\n" + "=" * 70)
    print("PART B: INTERACTION MODELS ON SPREADS")
    print("=" * 70)

    spread_int_results = []

    # B1: Z × debt → spread
    _fit_into(spread_int_results, df, 'sovereign_spread',
              demo_vars + controls_spread + ['govt_debt_gdp'] + int_vars_debt,
              "B1: Z×debt → spread")

    # B2: Z × KAOPEN → spread
    _fit_into(spread_int_results, df, 'sovereign_spread',
              demo_vars + controls_spread + int_vars_ka,
              "B2: Z×KAOPEN → spread")

    # B3: Z × fiscal → spread
    _fit_into(spread_int_results, df, 'sovereign_spread',
              demo_vars + controls_spread + int_vars_fiscal,
              "B3: Z×fiscal → spread")

    build_table(spread_int_results,
                demo_vars + int_vars_debt + int_vars_ka + int_vars_fiscal,
                "Interaction models on sovereign spread (10y - world avg)",
                "interactions_spreads.md",
                "Interaction Models: Demographics × Fiscal/Financial → Spreads")

    # ═══════════════════════════════════════════════════════════════════
    # PART C: MEDIATION — Does Z work through fiscal or directly?
    # ═══════════════════════════════════════════════════════════════════
    print("\n" + "=" * 70)
    print("PART C: MEDIATION ANALYSIS")
    print("=" * 70)

    med_results = []

    # C1: Z → rating (baseline, no fiscal controls)
    c1 = _fit_into(med_results, df, 'rating_numeric',
                   demo_vars + ['rgdp_growth', 'kaopen'],
                   "C1: Z → rating (no fiscal)")

    # C2: Z → rating (add fiscal_bal)
    _fit_into(med_results, df, 'rating_numeric',
              demo_vars + controls_rating,
              "C2: Z → rating (+ fiscal_bal)")

    # C3: Z → rating (add debt/GDP)
    c3 = _fit_into(med_results, df, 'rating_numeric',
                   rating_with_debt,
                   "C3: Z → rating (+ debt)")

    # C4: Z → rating (add debt + exp_rev_gap)
    fiscal_full = ['govt_debt_gdp']
    if 'exp_rev_gap' in df.columns:
        fiscal_full.append('exp_rev_gap')
    _fit_into(med_results, df, 'rating_numeric',
              demo_vars + controls_rating + fiscal_full,
              "C4: Z → rating (+ debt + exp_rev)")

    # C5: Z → rating (add primary_bal)
    if 'primary_bal_gdp' in df.columns:
        _fit_into(med_results, df, 'rating_numeric',
                  rating_with_debt + ['primary_bal_gdp'],
                  "C5: Z → rating (+ debt + primary)")

    # Compute attenuation of the Z_1 coefficient between C1 and C3.
    # The two fits are referenced directly rather than by position in
    # med_results: a skipped model shifts list positions and would make a
    # positional lookup compare the wrong pair.
    # NOTE(review): C1 and C3 are each estimated on their own complete-case
    # sample, so the attenuation also reflects any change in sample.
    if c1 and c3:
        z1_base = c1.get('coef_Z_1')
        z1_full = c3.get('coef_Z_1')
        # Explicit None checks: a coefficient of exactly 0.0 is a valid value.
        if z1_base is not None and z1_full is not None and z1_base != 0:
            attenuation = (1 - z1_full / z1_base) * 100
            print(f"\n  ★ Z₁ attenuation (no fiscal → + debt): {attenuation:.1f}%")
            if abs(attenuation) < 30:
                print("    → Demographics affect ratings DIRECTLY, not just through fiscal channel")
            else:
                print("    → Significant fiscal mediation")

    build_table(med_results,
                demo_vars + ['govt_debt_gdp', 'exp_rev_gap', 'primary_bal_gdp'],
                "Sequential addition of fiscal controls to test mediation",
                "mediation.md",
                "Mediation: Fiscal Channel vs Direct Demographic Effect on Ratings")

    # ═══════════════════════════════════════════════════════════════════
    # PART D: DYNAMIC MODELS
    # ═══════════════════════════════════════════════════════════════════
    print("\n" + "=" * 70)
    print("PART D: DYNAMIC MODELS")
    print("=" * 70)

    dyn_results = []

    # D1: Dynamic rating (lagged rating + Z)
    _fit_into(dyn_results, df, 'rating_numeric',
              ['rating_lag'] + rating_with_debt,
              "D1: dynamic rating")

    # D2: Dynamic spread (lagged spread + Z)
    _fit_into(dyn_results, df, 'sovereign_spread',
              ['sovereign_spread_lag'] + demo_vars + controls_spread,
              "D2: dynamic spread")

    # D3: Rating change (ΔR = R - R_lag)
    _fit_into(dyn_results, df, 'rating_change', rating_with_debt,
              "D3: Z → Δrating")

    # D4: Downgrade probability
    _fit_into(dyn_results, df, 'downgrade_any', rating_with_debt,
              "D4: Z → P(downgrade)")

    key_dyn = ['rating_lag', 'sovereign_spread_lag'] + demo_vars + ['govt_debt_gdp']
    build_table(dyn_results, key_dyn,
                "Dynamic models with lagged dependent variable",
                "dynamic_models.md",
                "Dynamic Models: Lagged DV + Demographics")

    # ═══════════════════════════════════════════════════════════════════
    # PART E: STRUCTURAL BREAKS
    # ═══════════════════════════════════════════════════════════════════
    print("\n" + "=" * 70)
    print("PART E: STRUCTURAL BREAKS (PRE/POST GFC)")
    print("=" * 70)

    break_results = []

    # Pre-GFC: years up to and including 2007. Post-GFC: 2008 onwards.
    pre = df[df['year'] <= 2007].copy()
    post = df[df['year'] >= 2008].copy()

    _fit_into(break_results, pre, 'rating_numeric', rating_with_debt,
              "E1: pre-GFC rating")
    _fit_into(break_results, post, 'rating_numeric', rating_with_debt,
              "E2: post-GFC rating")
    _fit_into(break_results, pre, 'sovereign_spread',
              demo_vars + controls_spread,
              "E3: pre-GFC spread")
    _fit_into(break_results, post, 'sovereign_spread',
              demo_vars + controls_spread,
              "E4: post-GFC spread")

    # Chow test for ratings
    print("\n  Chow test (rating model) ...")
    vars_chow = [v for v in rating_with_debt if v in df.columns]
    # Without the dependent variable, or with no regressors (k == 0), the
    # statistic is undefined; skip rather than raise.
    if 'rating_numeric' in df.columns and vars_chow:
        full = df.dropna(subset=['rating_numeric'] + vars_chow)
        pre_c = full[full['year'] <= 2007]
        post_c = full[full['year'] >= 2008]

        if len(pre_c) >= 50 and len(post_c) >= 50:
            # Restricted (pooled) fit, then the two unrestricted sub-period fits.
            rss_full = _gls_rss(full, 'rating_numeric', vars_chow)
            rss_pre = _gls_rss(pre_c, 'rating_numeric', vars_chow)
            rss_post = _gls_rss(post_c, 'rating_numeric', vars_chow)

            # k counts only the listed regressors.
            # NOTE(review): if PanelGLS also estimates an intercept or
            # effects, the degrees of freedom below would need adjusting.
            k = len(vars_chow)
            n = len(full)
            F_chow = ((rss_full - rss_pre - rss_post) / k) / ((rss_pre + rss_post) / (n - 2 * k))
            # Survival function = 1 - cdf, without the cancellation error
            # that 1 - cdf suffers for very small p-values.
            p_chow = stats.f.sf(F_chow, k, n - 2 * k)
            print(f"    Chow F = {F_chow:.2f}, p = {p_chow:.4f}")
            if p_chow < 0.05:
                print("    → STRUCTURAL BREAK CONFIRMED")
            # The test is carried through build_table by reusing model-row
            # keys: the F-statistic sits in the R² / coef_Z_1 slots and the
            # p-value in the se_Z_1 / p_Z_1 slots (see the table footnote).
            break_results.append({
                'label': 'Chow test', 'dep_var': 'rating_numeric',
                'n_obs': n, 'n_countries': full['iso3'].nunique(),
                'r_squared': F_chow, 'rho': p_chow,
                'coef_Z_1': F_chow, 'se_Z_1': p_chow, 'p_Z_1': p_chow,
            })

    build_table(break_results, demo_vars + ['govt_debt_gdp'],
                "Pre/post GFC subsamples. Last row: Chow F-stat (R² column) and p-value.",
                "structural_break.md",
                "Structural Break Analysis: Pre/Post GFC")

    print("\n" + "=" * 70)
    print("Phase 3 complete.")
    print("=" * 70)


# Run the Phase 3 pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
